2  Mixture Distributions and Mixture Models

Published

August 18, 2023

2.1 Review 1

sample from a distribution

2.2 A New Type of Random Variable

Latent Variable

define: latent variable

We learned about common univariate and multivariate distributions. For each of the distributions, there are well-defined and straightforward ways to sample values from the distribution.

Let’s consider a slightly different “data generation story” where we sample in two stages. (Data generation story from ESL [273])

  1. Sample from a discrete probability distribution with \(k\) unique values (i.e. Bernoulli distribution when \(k = 2\) and categorical distribution when \(k > 2\)).
  2. Sample from one of \(k\) different distributions conditional on the outcome of step 1.

Let’s consider a concrete example with a Bernoulli distribution and two normal distributions.

  1. Sample \(X \sim Bern(p = 0.25)\)
  2. Sample \(Y \sim N(\mu = 0, \sigma = 2)\) if \(X = 0\) and \(Y \sim N(\mu = 5, \sigma = 1)\) if \(X = 1\).

Now, let’s sample from a Bernoulli distribution and then sample from one of two normal distributions using R code.

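generate_data <- function(n) {

  # Step 1: draw the latent group, X ~ Bern(p = 0.25)
  step1 <- sample(x = c(0, 1), size = n, replace = TRUE, prob = c(0.75, 0.25))

  # Sort so the zeros come first and line up with the stacked draws in step 2
  step1 <- sort(step1)

  # Step 2: draw Y from the normal distribution selected by step 1
  step2 <- c(
    rnorm(n = sum(step1 == 0), mean = 0, sd = 2),
    rnorm(n = sum(step1 == 1), mean = 5, sd = 1)
  )

  tibble::tibble(
    x = step1,
    y = step2
  )

}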

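library(tidyverse)

set.seed(1)

# Plot the marginal density of y, ignoring the latent variable x
generate_data(n = 1000) %>%
  ggplot(aes(y)) +
  geom_density()

# To see the two component densities instead, use
# geom_density(aes(color = factor(x)))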

This marginal distribution looks complex but the process of creating the marginal distribution is simple.
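To make the link between the two sampling steps and the marginal distribution concrete, the sketch below writes the density of y as a weighted sum of the two normal densities. It assumes the weights, means, and standard deviations hard-coded in generate_data() above; the names `mix_weights`, `means`, `sds`, and `mixture_density` are introduced here only for illustration.

```r
# Mixture weights and component parameters copied from generate_data()
mix_weights <- c(0.75, 0.25)
means <- c(0, 5)
sds <- c(2, 1)

# Marginal density of y: a weighted sum of the two normal densities
mixture_density <- function(y) {
  mix_weights[1] * dnorm(y, mean = means[1], sd = sds[1]) +
    mix_weights[2] * dnorm(y, mean = means[2], sd = sds[2])
}

# A valid density integrates to one
total_area <- integrate(mixture_density, lower = -Inf, upper = Inf)$value
total_area

# The mean of the mixture is the weighted average of the component means
mixture_mean <- sum(mix_weights * means)
mixture_mean
```

Plotting `mixture_density()` over a grid of y values should trace roughly the same curve as the kernel density estimate from the simulated data.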

This two-step approach dramatically increases the types of distributions at our disposal because we are no longer limited to common univariate distributions like the normal distribution and uniform distribution. The two-step approach is also the foundation of two related tools:

  1. Mixture distributions: Distributions expressed as a weighted combination of other distributions, with nonnegative weights that sum to one. Mixture distributions can be very complicated distributions expressed in terms of simple distributions with known properties.
  2. Mixture models: Statistical inference about sub-populations made only with pooled data, without labels for the sub-populations.
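As a sketch of the general form, assume \(K\) component densities \(f_1, ..., f_K\) and mixture weights \(\pi_1, ..., \pi_K\) (notation introduced here for illustration). A mixture density can then be written as

\[ f(y) = \sum_{k = 1}^K \pi_k f_k(y), \qquad \pi_k \geq 0, \qquad \sum_{k = 1}^K \pi_k = 1 \]

In the two-step sampling story, \(\pi_k\) is the probability that step 1 selects component \(k\), and \(f_k\) is the distribution sampled in step 2.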

With mixture distributions, we care about the overall distribution and don’t care about the latent variables.

With mixture modeling, we use the overall distribution to learn about the latent variables/subpopulations/clusters in the data.
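For the two-normal example, learning about the latent variable amounts to applying Bayes' rule to each observed y. The sketch below treats the parameters hard-coded in generate_data() as known, which is an assumption made only for illustration; a mixture model would have to estimate them from the pooled data. The function name `posterior_x1` is hypothetical.

```r
# Posterior probability that X = 1 given an observed y, using Bayes' rule.
# The parameter defaults are copied from generate_data() and treated as known.
posterior_x1 <- function(y, p = 0.25) {
  numerator <- p * dnorm(y, mean = 5, sd = 1)
  denominator <- numerator + (1 - p) * dnorm(y, mean = 0, sd = 2)
  numerator / denominator
}

# Observations near 5 most likely come from the second normal distribution;
# observations near 0 most likely come from the first
posterior_x1(c(0, 3, 5))
```

These posterior probabilities are soft assignments: each observation receives a probability of belonging to each group rather than a single label.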

ancestral sampling: sample z_hat, sample x_hat|z_hat

2.3 Mixture Distribution

Define mixture density

“complex marginal distributions can be expressed as the joint distribution of observed and latent variables”

Geysers data (Bishop, Hands On ML in R, )

2.4 Review #2

multivariate-normal distribution

Review K-Means Clustering

hard assignments

2.5 Mixture Models/Model-Based Clustering

2.5.1 Gaussian Mixture Modeling (GMM)

parameters: means, covariance matrices, mixture weights/mixture proportions

over parameterized

soft assignments responsibility

challenges: 1. assumes a model 2. computation – overparameterized

library(censusapi)
us_youth_sahie <- getCensus(
  name = "timeseries/healthins/sahie",
  vars = c("GEOID", "PCTUI_PT"),
  region = "county:*",
  regionin = "state:*",
  time = 2019,
  AGECAT = 4
)

2.5.2 EM

ESL has a good explanation

2.5.3 BIC

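Likelihood: The likelihood measures how well a set of parameter values explains the observed data. If \(\theta\) is a vector of parameters, then \(L(\theta |x) = f(x |\theta)\) is the likelihood function: the density (or mass function) of the data, viewed as a function of \(\theta\) with \(x\) held fixed. The likelihood is not a probability distribution over \(\theta\) and need not sum or integrate to one. For independent observations the likelihood is a product of many small numbers, so we usually work with the log of the likelihood, which turns the product into a sum and avoids numerical underflow.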
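A small simulation can illustrate both points. The sketch below assumes independent draws from a normal distribution with known standard deviation 1; the function name `log_likelihood` and the candidate means are chosen only for illustration.

```r
set.seed(20230818)

# Simulated data with true mean 0 and standard deviation 1
x <- rnorm(n = 100, mean = 0, sd = 1)

# Log-likelihood of a candidate mean mu, holding the data fixed
log_likelihood <- function(mu, x) {
  sum(dnorm(x, mean = mu, sd = 1, log = TRUE))
}

# The candidate close to the truth has the higher log-likelihood
log_likelihood(mu = 0, x = x)
log_likelihood(mu = 3, x = x)

# With 1,000 observations the raw likelihood underflows to zero,
# while the log-likelihood remains a finite number
x_big <- rnorm(n = 1000)
raw_likelihood <- prod(dnorm(x_big))
log_lik_big <- sum(dnorm(x_big, log = TRUE))
raw_likelihood
log_lik_big
```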

For example, the likelihood of \(x = ?\) is blank given \(X_i \sim N(0, 1)\).

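The Bayesian information criterion, or BIC, is a likelihood-based measure of model fit that penalizes models for having many parameters. With \(m\) estimated parameters, \(n\) observations, and maximized likelihood \(\hat{L}\), a common definition is \(BIC = m\ln(n) - 2\ln(\hat{L})\), for which lower values are better. Some software (for example, the mclust package) reports the negative of this quantity, in which case higher values are better.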
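To show how the penalty enters, the sketch below computes the BIC by hand as the number of parameters times \(\ln(n)\), minus twice the maximized log-likelihood, and compares it with R's built-in BIC(). The linear model on the built-in cars data is only a convenient stand-in for illustration; it is not a mixture model.

```r
# Any fitted model with a logLik() method works; lm() on cars is a stand-in
fit <- lm(dist ~ speed, data = cars)

log_lik <- logLik(fit)
m <- attr(log_lik, "df")   # number of estimated parameters
n <- nobs(fit)             # number of observations

# BIC by hand: parameter penalty minus twice the maximized log-likelihood
bic_by_hand <- m * log(n) - 2 * as.numeric(log_lik)

bic_by_hand
BIC(fit)
```

Under this sign convention, adding a parameter only lowers the BIC if it raises the log-likelihood by more than \(\ln(n)/2\).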

explain Log-Likelihood explain BIC maximize BIC!

2.5.4 Example

Medicaid expansion Geyser data

2.5.5 Bernoulli Mixture Modeling (BMM)

Let’s consider a data generation story based on the Bernoulli distribution. Now, each variable, \(X_1, X_2, ..., X_D\), is drawn from a mixture of \(K\) Bernoulli distributions.

\[ X_d \sim \begin{cases} Bern(p_1) & \text{with probability }\pi_1 \\ Bern(p_2) & \text{with probability }\pi_2 \\ \vdots \\ Bern(p_K) & \text{with probability }\pi_K \end{cases} \tag{2.1}\]

Let \(i\) index the mixture components that contribute to the random variable. Summing over the components, each weighted by its mixture probability, the probability mass function of the random variable is written as

\[ P(X_d = x_d) = \sum_{i = 1}^K \pi_i \, p_i^{x_d} (1 - p_i)^{1 - x_d} \tag{2.2}\]
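Because \(X_d\) follows \(Bern(p_i)\) with probability \(\pi_i\) (Equation 2.1), its marginal probability mass is a \(\pi\)-weighted sum of the component Bernoulli masses. The sketch below evaluates that sum for a hypothetical two-component example; the values of `mix_weights` and `p` and the function name are made up for illustration.

```r
# Hypothetical two-component example; the values are for illustration only
mix_weights <- c(0.3, 0.7)   # pi_1, pi_2
p <- c(0.1, 0.8)             # p_1, p_2

# Marginal probability mass: weighted sum of the component Bernoulli masses
bernoulli_mixture_pmf <- function(x, p, mix_weights) {
  sum(mix_weights * p^x * (1 - p)^(1 - x))
}

prob_one <- bernoulli_mixture_pmf(x = 1, p = p, mix_weights = mix_weights)
prob_zero <- bernoulli_mixture_pmf(x = 0, p = p, mix_weights = mix_weights)

# The two probabilities sum to one
prob_one + prob_zero
```

For a single binary variable the mixture is itself a Bernoulli distribution with success probability \(\sum_i \pi_i p_i\). The components only become distinguishable when many variables share the same latent component, as with the pixels in the example that follows.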

Let’s consider a classic example from Bishop (2006) and Murphy (2022). The example uses the MNIST database, which contains 70,000 handwritten digits (60,000 training images and 10,000 test images). Each digit is stored as 784 variables, one for each pixel of a 28 by 28 grid, with values ranging from 0 to 255 that indicate the darkness of the pixel.

To prepare the data, we divide each pixel by 255 and then convert the pixels into indicators, coding values under 0.5 as 0 and values over 0.5 as 1. Let’s read in the data and then look at the label and the first nine pixel columns.
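The pre-processing step can be sketched on a handful of raw pixel values. The vector `raw_pixels` is hypothetical; the file read below already contains the binarized data.

```r
# Hypothetical raw pixel intensities on the 0 to 255 scale
raw_pixels <- c(0, 51, 127, 128, 255)

# Rescale to the unit interval, then threshold at 0.5
scaled_pixels <- raw_pixels / 255
binary_pixels <- as.integer(scaled_pixels > 0.5)

binary_pixels
```

An integer intensity divided by 255 can never equal exactly 0.5, so the choice between `>` and `>=` does not matter here.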

source(here::here("R", "visualize_digit.R"))

mnist <- read_csv(here::here("data", "mnist_binary.csv"))

glimpse(select(mnist, 1:10))
Rows: 60,000
Columns: 10
$ label    <dbl> 5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4…
$ pix_28_1 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_2 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_3 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_4 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_5 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_6 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_7 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_8 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_9 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…

Figure 2.1 shows the first four digits in the data after pre-processing.

visualize_digit(mnist, 1)
visualize_digit(mnist, 2)
visualize_digit(mnist, 3)
visualize_digit(mnist, 4)

(a) 5

(b) 0

(c) 4

(d) 1

Figure 2.1: First Four Digits

The digits are labelled in the MNIST data set, but we will ignore the labels and use Bernoulli Mixture Modeling as an unsupervised method. We will treat each pixel as its own Bernoulli distribution and cluster observations based on mixtures of 784 Bernoulli distributions. This means each cluster has 784 parameters, so a model with \(K\) clusters has \(784K\) Bernoulli parameters plus the mixture weights.
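To get a sense of scale, the sketch below counts the parameters for a few values of \(K\). It assumes one Bernoulli parameter per pixel per cluster and \(K - 1\) free mixture weights (the weights must sum to one); the function name `count_bmm_parameters` is hypothetical.

```r
# Parameter count for a Bernoulli mixture with k clusters:
# one p per pixel per cluster, plus k - 1 free mixture weights
count_bmm_parameters <- function(k, n_pixels = 784) {
  bernoulli_parameters <- n_pixels * k
  free_weights <- k - 1
  bernoulli_parameters + free_weights
}

parameter_counts <- count_bmm_parameters(k = c(2, 7, 10))
parameter_counts
```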

2.5.6 Two Digit Example

Let’s start with a simple example using just the digits “1” and “8”. We’ll use library(flexmix) by Leisch (2004). library(flexmix) is powerful but uses different syntax than we are used to.

  1. The function flexmix() expects a matrix.
  2. The formula expects the entire matrix on the left side of the ~.
  3. We specify the distribution used during the maximization (M) step with model = FLXMCmvbinary().

library(flexmix)
mnist_18 <- mnist |>
  filter(label %in% c("1", "8")) |>
  select(-label) |>
  as.matrix()

The starting assignments are random, so we set a seed.

set.seed(20230612)
mnist_18_clust <- flexmix(
  formula = mnist_18 ~ 1, 
  k = 2, 
  model = FLXMCmvbinary(), 
  control = list(iter.max = 100)
)
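To make the expectation (E) and maximization (M) steps concrete, here is a minimal EM sketch for a Bernoulli mixture, written from scratch on simulated data. It is not flexmix's implementation, and every name and value in it (`n`, `d`, `p_true`, `mix_weights`, the 50 iterations) is an assumption chosen for illustration.

```r
set.seed(20230612)

# Simulate binary data from two latent groups (all values are hypothetical)
n <- 500
d <- 6
z <- rbinom(n, size = 1, prob = 0.4) + 1        # latent group, 1 or 2
p_true <- rbind(rep(0.9, d), rep(0.1, d))       # one row of p per group
x <- matrix(rbinom(n * d, size = 1, prob = p_true[z, ]), nrow = n)

# Random starting values
k <- 2
mix_weights <- rep(1 / k, k)
p <- matrix(runif(k * d, min = 0.25, max = 0.75), nrow = k)

for (iter in seq_len(50)) {
  # E step: responsibilities, computed on the log scale for stability
  log_resp <- x %*% t(log(p)) + (1 - x) %*% t(log(1 - p))
  log_resp <- sweep(log_resp, 2, log(mix_weights), "+")
  resp <- exp(log_resp - apply(log_resp, 1, max))
  resp <- resp / rowSums(resp)

  # M step: update the weights and the Bernoulli parameters
  n_k <- colSums(resp)
  mix_weights <- n_k / n
  p <- (t(resp) %*% x) / n_k
  p <- pmin(pmax(p, 1e-6), 1 - 1e-6)            # keep log(p) finite
}

# Hard assignments: the component with the largest responsibility
hard_assignment <- max.col(resp, ties.method = "first")
table(truth = z, cluster = hard_assignment)
```

The cluster numbers are arbitrary, so cluster 1 may correspond to either true group. A fixed number of iterations keeps the sketch short; a fuller implementation would stop once the log-likelihood no longer improves.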

The MNIST data are already labelled, so we can compare our assignments to the labels if we convert the “soft assignments” to “hard assignments”. Note that most applications won’t have labels.

mnist |>
  filter(label %in% c("1", "8")) |>  
  bind_cols(cluster = mnist_18_clust@cluster) |>
  count(label, cluster)
# A tibble: 4 × 3
  label cluster     n
  <dbl>   <int> <int>
1     1       1   482
2     1       2  6260
3     8       1  5610
4     8       2   241
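The conversion from soft to hard assignments can be sketched with a small, made-up matrix of responsibilities. The matrix below is hypothetical; in the example above the hard assignments come from the `cluster` slot of the fitted object.

```r
# Hypothetical responsibilities: one row per observation, one column per cluster
responsibilities <- rbind(
  c(0.92, 0.08),
  c(0.30, 0.70),
  c(0.55, 0.45),
  c(0.01, 0.99)
)

# Soft assignments are probabilities, so each row sums to one
rowSums(responsibilities)

# Hard assignment: the cluster with the largest responsibility in each row
hard_assignment <- max.col(responsibilities, ties.method = "first")
hard_assignment
```

The third observation shows what hard assignments discard: it is labelled cluster 1 even though the model is nearly undecided.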

Figure 2.2 shows the estimated \(p_i\) for each pixel for each cluster. The figure visualizes the \(784K\) parameters that we estimated: 784 \(p_i\) for \(k = 1\) and 784 \(p_i\) for \(k = 2\). We see that the estimated parameters closely resemble the digits.

means_18 <- rbind(
  t(parameters(mnist_18_clust, component = 1)),
  t(parameters(mnist_18_clust, component = 2))
) |>
  as_tibble() |>
  mutate(label = NA)
visualize_digit(means_18, 1)
visualize_digit(means_18, 2)

(a) Component 1

(b) Component 2

Figure 2.2: Estimated Parameters for Each Cluster

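The BMM does a good job of labeling the digits and recovering the average shape of the digits. In the table above, 11,870 of the 12,593 digits (about 94 percent) fall in the cluster dominated by their own label.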

2.5.7 Ten Digit Example

Let’s now consider an example that uses all 10 digits.

In most applications, we won’t know the number of mixture components. First, we sample 1,000¹ digits and run the model with \(k = 2, 3, ..., 20\). We’ll calculate the BIC for each value of \(k\) and pick the \(k\) with the lowest BIC.

set.seed(20230613)
mnist_sample <- mnist |>
  slice_sample(n = 1000) |>
  select(-label) |>
  as.matrix()

steps <- stepFlexmix(
  formula = mnist_sample ~ 1, 
  model = FLXMCmvbinary(), 
  control = list(iter.max = 100, minprior = 0),
  k = 2:20, 
  nrep = 1
)

steps

\(k = 7\) provides the lowest BIC. This is probably because digits like 3 and 8 are very similar. Next, we run the BMM on the full data with \(k = 7\).

mnist_full <- mnist |>
  select(-label) |>
  as.matrix()

mnist_clust <- flexmix(
  formula = mnist_full ~ 1, 
  k = 7, 
  model = FLXMCmvbinary(), 
  control = list(iter.max = 200, minprior = 0)
)

As in the two-digit example, we convert the “soft assignments” to “hard assignments” and compare them with the labels.

mnist |>
  bind_cols(cluster = mnist_clust@cluster) |>
  count(label, cluster) |>
  arrange(cluster)
# A tibble: 69 × 3
   label cluster     n
   <dbl>   <int> <int>
 1     0       1     1
 2     1       1  6219
 3     2       1   157
 4     3       1   275
 5     4       1    72
 6     5       1    63
 7     6       1   262
 8     7       1   198
 9     8       1   365
10     9       1   123
# ℹ 59 more rows

Figures 2.3 through 2.9 show the estimated \(p_i\) for each pixel for each cluster. Together they visualize the \(784K\) parameters that we estimated: 784 \(p_i\) for each of the \(k = 1, 2, ..., 7\) clusters. We see that the estimated parameters closely resemble the digits.

means <- rbind(
  t(parameters(mnist_clust, component = 1)),
  t(parameters(mnist_clust, component = 2)),
  t(parameters(mnist_clust, component = 3)),
  t(parameters(mnist_clust, component = 4)),
  t(parameters(mnist_clust, component = 5)),
  t(parameters(mnist_clust, component = 6)),
  t(parameters(mnist_clust, component = 7))
) |>
  as_tibble() |>
  mutate(label = NA)

visualize_digit(means, 1)
visualize_digit(means, 2)
visualize_digit(means, 3)
visualize_digit(means, 4)
visualize_digit(means, 5)
visualize_digit(means, 6)
visualize_digit(means, 7)

Figures 2.3 to 2.9: Estimated Parameters for Each Cluster (one figure per component, components 1 to 7)

The example with all digits doesn’t result in 10 distinct mixture components, but it does a fairly good job of finding structure in the data. Without labels and considering the variety of messy handwriting, this is a useful model.

2.5.8 More Examples?

State human services typology

employed vs. unemployed vs. retired? yes/no on types of income

shopping cart

url <- "https://koalaverse.github.io/homlr/data/my_basket.csv"
my_basket <- readr::read_csv(url)
Rows: 2000 Columns: 42

  1. This is solely to save computation time.